{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4-2、张量的数学运算\n",
    "\n",
    "张量的操作主要包括张量的结构操作和张量的数学运算。\n",
    "\n",
    "张量结构操作诸如：张量创建，索引切片，维度变换，合并分割。\n",
    "\n",
    "张量数学运算主要有：标量运算，向量运算，矩阵运算。另外我们会介绍张量运算的广播机制。\n",
    "\n",
    "本篇我们介绍张量的数学运算。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 一、标量运算\n",
    "\n",
    "张量的数学运算符可以分为标量运算符、向量运算符、以及矩阵运算符。\n",
    "\n",
    "加减乘除乘方，以及三角函数，指数，对数等常见函数，逻辑比较运算符等都是标量运算符。\n",
    "\n",
    "标量运算符的特点是对张量实施逐元素运算。\n",
    "\n",
    "有些标量运算符对常用的数学运算符进行了重载。并且支持类似numpy的广播特性。\n",
    "\n",
    "许多标量运算符都在 tf.math模块下。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf \n",
    "import numpy as np "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       "array([[ 6.,  8.],\n",
       "       [ 4., 12.]], dtype=float32)>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = tf.constant([[1.0,2],[-3,4.0]])\n",
    "b = tf.constant([[5.0,6],[7.0,8.0]])\n",
    "a+b  #运算符重载"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       "array([[ -4.,  -4.],\n",
       "       [-10.,  -4.]], dtype=float32)>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a-b "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       "array([[  5.,  12.],\n",
       "       [-21.,  32.]], dtype=float32)>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a*b "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       "array([[ 0.2       ,  0.33333334],\n",
       "       [-0.42857143,  0.5       ]], dtype=float32)>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a/b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       "array([[ 1.,  4.],\n",
       "       [ 9., 16.]], dtype=float32)>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a**2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       "array([[1.       , 1.4142135],\n",
       "       [      nan, 2.       ]], dtype=float32)>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a**(0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       "array([[ 1.,  2.],\n",
       "       [-0.,  1.]], dtype=float32)>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a%3 #mod的运算符重载，等价于m = tf.math.mod(a,3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       "array([[ 0.,  0.],\n",
       "       [-1.,  1.]], dtype=float32)>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a//3  #地板除法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=bool, numpy=\n",
       "array([[False,  True],\n",
       "       [False,  True]])>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(a>=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=bool, numpy=\n",
       "array([[False,  True],\n",
       "       [False, False]])>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(a>=2)&(a<=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=bool, numpy=\n",
       "array([[ True,  True],\n",
       "       [ True,  True]])>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(a>=2)|(a<=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=bool, numpy=\n",
       "array([[False, False],\n",
       "       [False, False]])>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a==5 #tf.equal(a,5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       "array([[1.       , 1.4142135],\n",
       "       [      nan, 2.       ]], dtype=float32)>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.sqrt(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2,), dtype=float32, numpy=array([12., 21.], dtype=float32)>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = tf.constant([1.0,8.0])\n",
    "b = tf.constant([5.0,6.0])\n",
    "c = tf.constant([6.0,7.0])\n",
    "tf.add_n([a,b,c])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[5 8]\n"
     ]
    }
   ],
   "source": [
    "tf.print(tf.maximum(a,b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 6]\n"
     ]
    }
   ],
   "source": [
    "tf.print(tf.minimum(a,b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 二、向量运算\n",
    "\n",
    "向量运算符只在一个特定轴上运算，将一个向量映射到一个标量或者另外一个向量。\n",
    "许多向量运算符都以reduce开头。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 2 3 ... 7 8 9]\n"
     ]
    }
   ],
   "source": [
    "#向量reduce\n",
    "a = tf.range(1,10)\n",
    "tf.print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "求和： 45\n",
      "均值： 5\n",
      "最大值： 9\n",
      "最小值： 1\n",
      "乘积： 362880\n"
     ]
    }
   ],
   "source": [
    "tf.print(\"求和：\", tf.reduce_sum(a))\n",
    "tf.print(\"均值：\", tf.reduce_mean(a))\n",
    "tf.print(\"最大值：\", tf.reduce_max(a))\n",
    "tf.print(\"最小值：\", tf.reduce_min(a))\n",
    "# 函数计算一个张量的各个维度上元素的乘积（张量沿着某一维度计算乘积）\n",
    "tf.print(\"乘积：\", tf.reduce_prod(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1 2 3]\n",
      " [4 5 6]\n",
      " [7 8 9]]\n",
      "[[6]\n",
      " [15]\n",
      " [24]]\n",
      "[6 15 24]\n",
      "[[12 15 18]]\n"
     ]
    }
   ],
   "source": [
    "# 张量制定维度进行 reduce\n",
    "b = tf.reshape(a, (3, 3))\n",
    "tf.print(b)\n",
    "tf.print(tf.reduce_sum(b, axis=1, keepdims=True)) # 行方向求和\n",
    "tf.print(tf.reduce_sum(b, axis=1, keepdims=False)) \n",
    "tf.print(tf.reduce_sum(b, axis=0, keepdims=True)) # 列方向求和"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n"
     ]
    }
   ],
   "source": [
    "#bool类型的reduce\n",
    "p = tf.constant([True,False,False])\n",
    "q = tf.constant([False,False,True])\n",
    "tf.print(tf.reduce_all(p))\n",
    "tf.print(tf.reduce_any(q))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "45\n"
     ]
    }
   ],
   "source": [
    "#利用tf.foldr实现tf.reduce_sum\n",
    "s = tf.foldr(lambda a, b:a+b, tf.range(10))\n",
    "tf.print(s)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* cumsum()累积连加\n",
    "* cumprod()累积连乘"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 3 6 ... 28 36 45]\n",
      "[1 2 6 ... 5040 40320 362880]\n"
     ]
    }
   ],
   "source": [
    "#cum扫描累积\n",
    "a = tf.range(1,10)\n",
    "tf.print(tf.math.cumsum(a)) # 沿着tensor（张量）x的某一个维度axis，计算累积和。\n",
    "tf.print(tf.math.cumprod(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 2 3 ... 7 8 9]\n",
      "8\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "#arg最大最小值索引\n",
    "a = tf.range(1,10)\n",
    "tf.print(a)\n",
    "tf.print(tf.argmax(a))\n",
    "tf.print(tf.argmin(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[8 7 5]\n",
      "[5 2 3]\n"
     ]
    }
   ],
   "source": [
    "# tf.math.top_k 用于对张量的排序\n",
    "a = tf.constant([1, 3, 7, 5, 4, 8])\n",
    "\n",
    "values, indexs = tf.math.top_k(a, 3, sorted=True)\n",
    "tf.print(values)\n",
    "tf.print(indexs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 三、矩阵运算\n",
    "\n",
    "矩阵必须是二维的。类似tf.constant([1,2,3])这样的不是矩阵。\n",
    "\n",
    "矩阵运算包括：矩阵乘法，矩阵转置，矩阵逆，矩阵求迹，矩阵范数，矩阵行列式，矩阵求特征值，矩阵分解等运算。\n",
    "\n",
    "除了一些常用的运算外，大部分和矩阵有关的运算都在tf.linalg子包中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[2 4]\n",
      " [6 8]]\n",
      "\n",
      "[[2 4]\n",
      " [6 8]]\n"
     ]
    }
   ],
   "source": [
    "#矩阵乘法\n",
    "a = tf.constant([[1,2],[3,4]])\n",
    "b = tf.constant([[2,0],[0,2]])\n",
    "tf.print(a@b)  #等价于tf.matmul(a,b)\n",
    "tf.print(\"\")\n",
    "tf.print(tf.matmul(a, b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1 2]\n",
      " [3 4]]\n",
      "[[1 3]\n",
      " [2 4]]\n"
     ]
    }
   ],
   "source": [
    "#矩阵转置\n",
    "a = tf.constant([[1.0,2],[3,4]])\n",
    "tf.print(a)\n",
    "tf.print(tf.transpose(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1 2]\n",
      " [3 4]]\n",
      "\n",
      "[[-2.00000024 1.00000012]\n",
      " [1.50000012 -0.50000006]]\n"
     ]
    }
   ],
   "source": [
    "#矩阵逆，必须为tf.float32或tf.double类型\n",
    "a = tf.constant([[1.0,2], [3.0,4]], dtype = tf.float32)\n",
    "tf.print(a)\n",
    "tf.print(\"\")\n",
    "tf.print(tf.linalg.inv(a))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "矩阵的迹（trace）表示矩阵 AAA 主对角线所有元素的和，即\n",
    "tr(A)=a11+a22+⋯+ann"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1 2]\n",
      " [3 4]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=5.0>"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#矩阵求trace \n",
    "a = tf.constant([[1.0,2],[3,4]])\n",
    "tf.print(a)\n",
    "tf.linalg.trace(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=5.477226>"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#矩阵求范数\n",
    "a = tf.constant([[1.0,2],[3,4]])\n",
    "tf.linalg.norm(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-2\n"
     ]
    }
   ],
   "source": [
    "# 矩阵行列式\n",
    "a = tf.constant([[1.0,2],[3,4]])\n",
    "tf.print(tf.linalg.det(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-0.854102075 5.85410213]\n"
     ]
    }
   ],
   "source": [
    "# 矩阵特征值\n",
    "tf.print(tf.linalg.eigvalsh(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-0.316227794 -0.948683321]\n",
      " [-0.948683321 0.316227734]]\n",
      "[[-3.1622777 -4.4271884]\n",
      " [0 -0.632455349]]\n",
      "[[1.00000012 1.99999976]\n",
      " [3 4]]\n"
     ]
    }
   ],
   "source": [
    "#矩阵qr分解\n",
    "a = tf.constant([[1.0,2.0],[3.0,4.0]],dtype = tf.float32)\n",
    "q, r = tf.linalg.qr(a)\n",
    "tf.print(q)\n",
    "tf.print(r)\n",
    "tf.print(q@r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[5.46498537 0.365966141]\n",
      "对角矩阵: [[5.46498537 0]\n",
      " [0 0.365966141]]\n",
      "[[0.404553503 -0.914514303]\n",
      " [0.914514303 0.404553503]]\n",
      "[[0.576048374 0.817415595]\n",
      " [0.817415595 -0.576048374]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       "array([[0.9999996, 1.9999996],\n",
       "       [2.9999998, 4.       ]], dtype=float32)>"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#矩阵svd分解\n",
    "a  = tf.constant([[1.0,2.0],[3.0,4.0]],dtype = tf.float32)\n",
    "v, s, d = tf.linalg.svd(a)\n",
    "tf.print(v)\n",
    "tf.print(\"对角矩阵:\",tf.linalg.diag(v))\n",
    "tf.print(s)\n",
    "tf.print(d)\n",
    "tf.matmul(tf.matmul(s, tf.linalg.diag(v)), d)\n",
    "#利用svd分解可以在TensorFlow中实现主成分分析降维"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 四，广播机制\n",
    "\n",
    "TensorFlow的广播规则和numpy是一样的:\n",
    "\n",
    "* 1、如果张量的维度不同，将维度较小的张量进行扩展，直到两个张量的维度都一样。\n",
    "* 2、如果两个张量在某个维度上的长度是相同的，或者其中一个张量在该维度上的长度为1，那么我们就说这两个张量在该维度上是相容的。\n",
    "* 3、如果两个张量在所有维度上都是相容的，它们就能使用广播。\n",
    "* 4、广播之后，每个维度的长度将取两个张量在该维度长度的较大值。\n",
    "* 5、在任何一个维度上，如果一个张量的长度为1，另一个张量长度大于1，那么在该维度上，就好像是对第一个张量进行了复制。\n",
    "\n",
    "tf.broadcast_to 以显式的方式按照广播机制扩展张量的维度。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 2 3]\n",
      " \n",
      "[[0 0 0]\n",
      " [1 1 1]\n",
      " [2 2 2]]\n",
      " \n",
      "a + b:\n",
      "[[1 2 3]\n",
      " [2 3 4]\n",
      " [3 4 5]]\n"
     ]
    }
   ],
   "source": [
    "a = tf.constant([1,2,3])\n",
    "b = tf.constant([[0,0,0],[1,1,1],[2,2,2]])\n",
    "tf.print(a)\n",
    "tf.print(\" \")\n",
    "tf.print(b)\n",
    "tf.print(\" \")\n",
    "tf.print(\"a + b:\")\n",
    "tf.print(a + b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3, 3), dtype=int32, numpy=\n",
       "array([[1, 2, 3],\n",
       "       [1, 2, 3],\n",
       "       [1, 2, 3]], dtype=int32)>"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.broadcast_to(a, b.shape) # 以显式的方式按照广播机制扩展张量的维度。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([3, 3])"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#计算广播后计算结果的形状，静态形状，TensorShape类型参数\n",
    "tf.broadcast_static_shape(a.shape,b.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2,), dtype=int32, numpy=array([3, 3], dtype=int32)>"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#计算广播后计算结果的形状，动态形状，Tensor类型参数\n",
    "c = tf.constant([1,2,3])\n",
    "d = tf.constant([[1],[2],[3]])\n",
    "tf.broadcast_dynamic_shape(tf.shape(c), tf.shape(d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3, 3), dtype=int32, numpy=\n",
       "array([[2, 3, 4],\n",
       "       [3, 4, 5],\n",
       "       [4, 5, 6]], dtype=int32)>"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#广播效果\n",
    "c+d #等价于 tf.broadcast_to(c,[3,3]) + tf.broadcast_to(d,[3,3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3, 3), dtype=int32, numpy=\n",
       "array([[2, 3, 4],\n",
       "       [3, 4, 5],\n",
       "       [4, 5, 6]], dtype=int32)>"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.broadcast_to(c,[3,3]) + tf.broadcast_to(d,[3,3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
